自定义模型¶
Note
自定义模型就像自定义层,通过继承keras.Model来实现我们想要的功能。
残差网络¶
我们通过实现残差网络来进行阐述如何自定义模型。
keras.Model是keras.layers.Layer的子类,因此可以像自定义层一样定义和使用模型。
但是模型有些额外的功能,包括compile()、fit()、evaluate()、predict()方法。
from tensorflow import keras
class ResidualBlock(keras.layers.Layer):
# 自定义残差块
def __init__(self, n_layers, n_neurons, **kwargs):
# kwargs handles standard args (e.g., name)
super().__init__(**kwargs)
self.hidden = [keras.layers.Dense(n_neurons, activation='relu')
for _ in range(n_layers)]
def call(self, inputs):
Z = inputs
for layer in self.hidden:
Z = layer(Z)
return inputs + Z
class ResidualRegressor(keras.models.Model):
# 自定义残差网络
def __init__(self, output_dim, **kwargs):
super().__init__(**kwargs)
self.hidden1 = keras.layers.Dense(30, activation="relu")
# 两个残差块
self.block1 = ResidualBlock(2, 30)
self.block2 = ResidualBlock(2, 30)
self.out = keras.layers.Dense(output_dim)
def call(self, inputs):
Z = self.hidden1(inputs)
# 重复残差两次
for _ in range(2):
Z = self.block1(Z)
Z = self.block2(Z)
return self.out(Z)
Note
同自定义层一样,与input_shape相关的initialization需在build()中实现